{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "import accelerate\n",
    "import peft"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hub id of the pre-trained ViT checkpoint. Later cells derive the image processor,\n",
    "# the classification model and the output directory name from this one value.\n",
    "model_checkpoint = \"google/vit-base-patch16-224-in21k\"  # pre-trained model from which to fine-tune"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ea293c08a90446a181962959fe13ab1c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "# Shows the interactive Hugging Face login widget. The training arguments further down\n",
    "# set push_to_hub=True, so the session needs to be authenticated before training starts.\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Route Hugging Face Hub traffic through a mirror endpoint.\n",
    "# NOTE(review): huggingface_hub appears to read HF_ENDPOINT once, when it is first\n",
    "# imported, and earlier cells already import transformers / huggingface_hub. On a fresh\n",
    "# kernel this assignment therefore probably only affects libraries imported afterwards.\n",
    "# Consider moving this cell above the first import cell - TODO confirm.\n",
    "os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# Load only a slice of the food101 training split (the `train[:5000]` split expression)\n",
    "# instead of the whole dataset.\n",
    "dataset = load_dataset(\"food101\", split=\"train[:5000]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset({\n",
      "    features: ['image', 'label'],\n",
      "    num_rows: 5000\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: show the dataset's column names and row count.\n",
    "print(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class names come from the dataset's `label` feature; a name's position in this\n",
    "# list is used as its integer class id.\n",
    "labels = dataset.features[\"label\"].names\n",
    "\n",
    "# Lookup tables in both directions (name -> id and id -> name); both are passed to the\n",
    "# model when it is loaded.\n",
    "label2id = {name: idx for idx, name in enumerate(labels)}\n",
    "id2label = {idx: name for idx, name in enumerate(labels)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "ViTImageProcessor {\n",
       "  \"do_normalize\": true,\n",
       "  \"do_rescale\": true,\n",
       "  \"do_resize\": true,\n",
       "  \"image_mean\": [\n",
       "    0.5,\n",
       "    0.5,\n",
       "    0.5\n",
       "  ],\n",
       "  \"image_processor_type\": \"ViTImageProcessor\",\n",
       "  \"image_std\": [\n",
       "    0.5,\n",
       "    0.5,\n",
       "    0.5\n",
       "  ],\n",
       "  \"resample\": 2,\n",
       "  \"rescale_factor\": 0.00392156862745098,\n",
       "  \"size\": {\n",
       "    \"height\": 224,\n",
       "    \"width\": 224\n",
       "  }\n",
       "}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import AutoImageProcessor\n",
    "\n",
    "# Load the preprocessing configuration published with the checkpoint. Its `size`,\n",
    "# `image_mean` and `image_std` fields are reused when the torchvision transforms are built.\n",
    "image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)\n",
    "image_processor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.transforms import (\n",
    "    CenterCrop,\n",
    "    Compose,\n",
    "    Normalize,\n",
    "    RandomHorizontalFlip,\n",
    "    RandomResizedCrop,\n",
    "    Resize,\n",
    "    ToTensor,\n",
    ")\n",
    "\n",
    "# Normalise with the mean / std stored in the checkpoint's image processor.\n",
    "normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)\n",
    "\n",
    "# Training pipeline: random resized crop and horizontal flip, then tensor conversion\n",
    "# and normalisation.\n",
    "train_transforms = Compose([\n",
    "    RandomResizedCrop(image_processor.size[\"height\"]),\n",
    "    RandomHorizontalFlip(),\n",
    "    ToTensor(),\n",
    "    normalize,\n",
    "])\n",
    "\n",
    "# Validation pipeline: resize followed by a centre crop, then tensor conversion and\n",
    "# normalisation. No random operations are used here.\n",
    "val_transforms = Compose([\n",
    "    Resize(image_processor.size[\"height\"]),\n",
    "    CenterCrop(image_processor.size[\"height\"]),\n",
    "    ToTensor(),\n",
    "    normalize,\n",
    "])\n",
    "\n",
    "\n",
    "def preprocess_train(example_batch):\n",
    "    \"\"\"Add a `pixel_values` entry built with `train_transforms`.\n",
    "\n",
    "    Every item of `example_batch[\"image\"]` is converted to RGB and passed through\n",
    "    the training pipeline. The same batch object is updated and returned.\n",
    "    \"\"\"\n",
    "    transformed = []\n",
    "    for image in example_batch[\"image\"]:\n",
    "        transformed.append(train_transforms(image.convert(\"RGB\")))\n",
    "    example_batch[\"pixel_values\"] = transformed\n",
    "    return example_batch\n",
    "\n",
    "\n",
    "def preprocess_val(example_batch):\n",
    "    \"\"\"Add a `pixel_values` entry built with `val_transforms`.\n",
    "\n",
    "    Every item of `example_batch[\"image\"]` is converted to RGB and passed through\n",
    "    the validation pipeline. The same batch object is updated and returned.\n",
    "    \"\"\"\n",
    "    transformed = []\n",
    "    for image in example_batch[\"image\"]:\n",
    "        transformed.append(val_transforms(image.convert(\"RGB\")))\n",
    "    example_batch[\"pixel_values\"] = transformed\n",
    "    return example_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the loaded examples into training (90%) and validation (10%) subsets.\n",
    "# The fixed seed keeps the split - and therefore the reported metrics - the same\n",
    "# across kernel restarts; without it every run trains and evaluates on different rows.\n",
    "splits = dataset.train_test_split(test_size=0.1, seed=42)\n",
    "train_ds = splits[\"train\"]\n",
    "val_ds = splits[\"test\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach the matching preprocessing function to each split.\n",
    "# NOTE(review): per the `datasets` documentation, `set_transform` applies the function\n",
    "# on the fly when rows are accessed rather than materialising `pixel_values` up front -\n",
    "# confirm for the installed version.\n",
    "train_ds.set_transform(preprocess_train)\n",
    "val_ds.set_transform(preprocess_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_trainable_parameters(model):\n",
    "    \"\"\"Print how many of the model's parameters are trainable.\n",
    "\n",
    "    Walks `model.named_parameters()`, sums the element counts of all parameters and\n",
    "    of those with `requires_grad` set, and prints both totals together with the\n",
    "    trainable share as a percentage with two decimals.\n",
    "    \"\"\"\n",
    "    # (element count, is trainable) for every parameter tensor\n",
    "    counts = [(param.numel(), param.requires_grad) for _, param in model.named_parameters()]\n",
    "    all_param = sum(size for size, _ in counts)\n",
    "    trainable_params = sum(size for size, trainable in counts if trainable)\n",
    "    print(\n",
    "        f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "ViTForImageClassification(\n",
       "  (vit): ViTModel(\n",
       "    (embeddings): ViTEmbeddings(\n",
       "      (patch_embeddings): ViTPatchEmbeddings(\n",
       "        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n",
       "      )\n",
       "      (dropout): Dropout(p=0.0, inplace=False)\n",
       "    )\n",
       "    (encoder): ViTEncoder(\n",
       "      (layer): ModuleList(\n",
       "        (0-11): 12 x ViTLayer(\n",
       "          (attention): ViTSdpaAttention(\n",
       "            (attention): ViTSdpaSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.0, inplace=False)\n",
       "            )\n",
       "            (output): ViTSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.0, inplace=False)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): ViTIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (intermediate_act_fn): GELUActivation()\n",
       "          )\n",
       "          (output): ViTOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (dropout): Dropout(p=0.0, inplace=False)\n",
       "          )\n",
       "          (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "          (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
       "  )\n",
       "  (classifier): Linear(in_features=768, out_features=101, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import AutoModelForImageClassification, TrainingArguments, Trainer\n",
    "\n",
    "# Load the pre-trained ViT backbone with an image-classification head. The label\n",
    "# mappings are passed in so the head is built for this dataset's classes; the warning\n",
    "# printed below confirms that the classifier weights are newly initialised.\n",
    "model = AutoModelForImageClassification.from_pretrained(\n",
    "    model_checkpoint,\n",
    "    label2id=label2id,\n",
    "    id2label=id2label,\n",
    "    ignore_mismatched_sizes=True,  # provide this in case you're planning to fine-tune an already fine-tuned checkpoint\n",
    ")\n",
    "# print_trainable_parameters(model)\n",
    "# Display the module tree; the `query` / `value` layer names shown here are the ones\n",
    "# targeted by the LoRA configuration in the next cell.\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 667493 || all params: 86543818 || trainable%: 0.77\n"
     ]
    }
   ],
   "source": [
    "from peft import LoraConfig, get_peft_model\n",
    "\n",
    "# LoRA adapter configuration for the ViT model loaded above.\n",
    "config = LoraConfig(\n",
    "    r=16,  # rank of the low-rank adapter matrices\n",
    "    lora_alpha=16,  # LoRA scaling parameter\n",
    "    target_modules=[\"query\", \"value\"],  # the attention `query` / `value` Linear layers of the ViT\n",
    "    lora_dropout=0.1,\n",
    "    bias=\"none\",\n",
    "    # The classification head is trained in full next to the adapters: its\n",
    "    # 768 * 101 + 101 = 77,669 parameters are part of the trainable count printed below.\n",
    "    modules_to_save=[\"classifier\"],\n",
    ")\n",
    "# Wrap the base model with the adapters and report the resulting trainable share.\n",
    "lora_model = get_peft_model(model, config)\n",
    "print_trainable_parameters(lora_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "\n",
    "# Name the output directory after the checkpoint, minus its organisation prefix.\n",
    "model_name = model_checkpoint.split(\"/\")[-1]\n",
    "batch_size = 128\n",
    "\n",
    "args = TrainingArguments(\n",
    "    output_dir=f\"{model_name}-finetuned-lora-food101\",\n",
    "    # --- optimisation ---\n",
    "    learning_rate=5e-3,\n",
    "    num_train_epochs=5,\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    gradient_accumulation_steps=4,  # 4 x 128 = 512 examples per optimizer step and device\n",
    "    fp16=True,\n",
    "    # --- evaluation / checkpointing: both happen once per epoch ---\n",
    "    eval_strategy=\"epoch\",\n",
    "    save_strategy=\"epoch\",\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model=\"accuracy\",\n",
    "    logging_steps=10,\n",
    "    # --- data handling ---\n",
    "    remove_unused_columns=False,  # keep the raw `image` column that the preprocessing transforms read\n",
    "    label_names=[\"labels\"],  # same key that collate_fn stores the targets under\n",
    "    # --- publishing ---\n",
    "    push_to_hub=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import evaluate\n",
    "\n",
    "metric = evaluate.load(\"accuracy\")\n",
    "\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Compute accuracy for one evaluation pass.\n",
    "\n",
    "    `eval_pred` is a named tuple whose `predictions` field holds the model logits\n",
    "    and whose `label_ids` field holds the ground-truth labels, both as NumPy arrays.\n",
    "    The predicted class is the arg-max over axis 1 of the logits.\n",
    "    \"\"\"\n",
    "    predicted_ids = np.argmax(eval_pred.predictions, axis=1)\n",
    "    references = eval_pred.label_ids\n",
    "    return metric.compute(predictions=predicted_ids, references=references)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def collate_fn(examples):\n",
    "    \"\"\"Merge a list of examples into one batch dictionary.\n",
    "\n",
    "    Stacks each example's `pixel_values` tensor along a new leading dimension and\n",
    "    gathers each example's `label` into a single tensor, returned under the keys\n",
    "    `pixel_values` and `labels`.\n",
    "    \"\"\"\n",
    "    image_tensors = tuple(item[\"pixel_values\"] for item in examples)\n",
    "    stacked = torch.stack(image_tensors)\n",
    "    targets = torch.tensor([item[\"label\"] for item in examples])\n",
    "    return dict(pixel_values=stacked, labels=targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\19895\\AppData\\Local\\conda\\conda\\envs\\tg\\lib\\site-packages\\accelerate\\accelerator.py:494: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n",
      "  self.scaler = torch.cuda.amp.GradScaler(**kwargs)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5b7c23601ba246aebd325ef408f8f0b4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/45 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\Users\\19895\\AppData\\Local\\conda\\conda\\envs\\tg\\lib\\site-packages\\transformers\\models\\vit\\modeling_vit.py:252: UserWarning: 1Torch was not compiled with flash attention. (Triggered internally at C:\\actions-runner\\_work\\pytorch\\pytorch\\builder\\windows\\pytorch\\aten\\src\\ATen\\native\\transformers\\cuda\\sdp_utils.cpp:555.)\n",
      "  context_layer = torch.nn.functional.scaled_dot_product_attention(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "85cbdb4de45c49878b9e8c054892370a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.5615572333335876, 'eval_accuracy': 0.868, 'eval_runtime': 37.8138, 'eval_samples_per_second': 13.223, 'eval_steps_per_second': 0.106, 'epoch': 1.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "'(ProtocolError('Connection aborted.', TimeoutError(10060, '由于连接方在一段时间后没有正确答复或连接的主机没有反应，连接尝试失败。', None, 10060, None)), '(Request ID: e93a87aa-325e-476d-9576-42901da42ff3)')' thrown while requesting PUT https://hf-hub-lfs-us-east-1.s3-accelerate.amazonaws.com/repos/3f/b6/3fb64b994b7e9c94ae8373328ec555886624ccb4ef3bf137728360a6b8a0bef8/7a0ecadb8a10cf24e0a7ebf2d1ddcb52d269c7de199e9cd84fdd5e2760a85e0b?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Content-Sha256=UNSIGNED-PAYLOAD&X-Amz-Credential=AKIA2JU7TKAQLC2QXPN7%2F20240925%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20240925T064653Z&X-Amz-Expires=900&X-Amz-Signature=7b31e240b857f26b8ba43adb16a06b5f8af9ae34eec8698784e9f5e42fe765fd&X-Amz-SignedHeaders=host&x-amz-storage-class=INTELLIGENT_TIERING&x-id=PutObject\n",
      "Retrying in 1s [Retry 1/5].\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'loss': 2.1801, 'grad_norm': 0.4545424282550812, 'learning_rate': 0.003888888888888889, 'epoch': 1.11}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2122a4775c214dacbc98a42d7fd421a7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.1696804314851761, 'eval_accuracy': 0.938, 'eval_runtime': 35.2077, 'eval_samples_per_second': 14.201, 'eval_steps_per_second': 0.114, 'epoch': 2.0}\n",
      "{'loss': 0.3292, 'grad_norm': 0.14382442831993103, 'learning_rate': 0.002777777777777778, 'epoch': 2.22}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "304774fc61614f338169a27176bc1c4b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.11115420609712601, 'eval_accuracy': 0.962, 'eval_runtime': 38.2923, 'eval_samples_per_second': 13.057, 'eval_steps_per_second': 0.104, 'epoch': 3.0}\n",
      "{'loss': 0.1956, 'grad_norm': 0.15067866444587708, 'learning_rate': 0.0016666666666666666, 'epoch': 3.33}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7357431d3ff442e5b6ccb56a53a8213e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.10155972838401794, 'eval_accuracy': 0.966, 'eval_runtime': 38.1734, 'eval_samples_per_second': 13.098, 'eval_steps_per_second': 0.105, 'epoch': 4.0}\n",
      "{'loss': 0.1571, 'grad_norm': 0.20834752917289734, 'learning_rate': 0.0005555555555555556, 'epoch': 4.44}\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4e23aae8154c4a738ee7a9622ec2d1ee",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'eval_loss': 0.0965295284986496, 'eval_accuracy': 0.968, 'eval_runtime': 39.0391, 'eval_samples_per_second': 12.808, 'eval_steps_per_second': 0.102, 'epoch': 5.0}\n",
      "{'train_runtime': 4223.0665, 'train_samples_per_second': 5.328, 'train_steps_per_second': 0.011, 'train_loss': 0.655164533191257, 'epoch': 5.0}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No files have been modified since last commit. Skipping to prevent empty commit.\n"
     ]
    }
   ],
   "source": [
    "# Assemble the Trainer around the LoRA-wrapped model and run fine-tuning.\n",
    "trainer = Trainer(\n",
    "    lora_model,\n",
    "    args,\n",
    "    train_dataset=train_ds,\n",
    "    eval_dataset=val_ds,\n",
    "    # NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of\n",
    "    # `processing_class=` - confirm against the installed version before changing it.\n",
    "    tokenizer=image_processor,\n",
    "    compute_metrics=compute_metrics,\n",
    "    data_collator=collate_fn,\n",
    ")\n",
    "# Keep the returned training summary for later inspection.\n",
    "train_results = trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9a22d3faa5764ba5b77694a721cd6c69",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 0.0965295284986496,\n",
       " 'eval_accuracy': 0.968,\n",
       " 'eval_runtime': 35.2283,\n",
       " 'eval_samples_per_second': 14.193,\n",
       " 'eval_steps_per_second': 0.114,\n",
       " 'epoch': 5.0}"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score the validation split once more with the model held by the trainer after\n",
    "# training. The loss and accuracy shown below match the epoch-5 evaluation logged above.\n",
    "trainer.evaluate(val_ds)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tg",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
